
[Experiment] ROCm backend#2300

Open
NripeshN wants to merge 206 commits into ml-explore:main from NripeshN:rocm-support

Conversation

@NripeshN
Contributor

@NripeshN NripeshN commented Jun 16, 2025

Experiment with ROCm backend.

Install MLX with the ROCm backend using:

mkdir build && cd build
cmake -DMLX_BUILD_ROCM=ON \
      -DCMAKE_PREFIX_PATH=/opt/rocm \
      -DCMAKE_HIP_ARCHITECTURES="gfx90a;gfx1100" \
      ..
make -j$(nproc)

closes #2556

Inspired by @zcbenz

@NripeshN NripeshN changed the title [Experiment] ROCm backend initial push [Experiment] ROCm backend Jun 16, 2025
@lin72h

lin72h commented Jun 17, 2025

What an unexpected and amazing surprise! I'm absolutely thrilled.

@NripeshN
Contributor Author

@awni
What do you think of this PR? Does this have the potential to be merged into main? I can turn this PR from experimental to WIP if so.

@angeloskath
Member

I think this is good to stay as an experiment branch for some time while we work on core and CUDA. I don't think we have the bandwidth to merge this for a few months at least. Sorry if this is disappointing, @NripeshN, I don't mean to discourage you from working on it.

@akshat2602

I would love to see the ROCm backend get more traction. AMD's new AI series of processors has a unified-memory advantage similar to Apple Silicon's, and getting MLX to run on those processors would be neat.

@countradooku

Stole my idea :(

@goniz

goniz commented Jan 22, 2026

How is this even possible for such an awesome PR to be left like this?

Copilot AI review requested due to automatic review settings January 24, 2026 17:08

Copilot AI left a comment


Pull request overview

This PR adds experimental ROCm backend support to MLX, enabling execution on AMD GPUs. The implementation mirrors the CUDA backend structure, providing HIP-based implementations of core operations, memory management, and device handling.

Changes:

  • Added ROCm backend infrastructure with device management, memory allocation, and stream handling
  • Implemented HIP kernels for unary, binary, ternary operations, reductions, normalization (softmax, layer_norm, rms_norm), RoPE, and sorting
  • Updated build system (CMake) to support ROCm compilation with configurable GPU architectures

Reviewed changes

Copilot reviewed 59 out of 59 changed files in this pull request and generated 13 comments.

Summary per file:

  • CMakeLists.txt: Added MLX_BUILD_ROCM option and ROCm library detection
  • mlx/CMakeLists.txt: Integrated ROCm backend build configuration
  • mlx/device.cpp: Added ROCm device availability checks
  • mlx/backend/rocm/*.hip: HIP kernel implementations for various operations
  • mlx/backend/rocm/device.*: ROCm device and stream management
  • mlx/backend/rocm/allocator.*: ROCm-specific memory allocator using HIP unified memory
  • mlx/backend/rocm/worker.*: Async task execution worker for stream synchronization
  • mlx/backend/rocm/utils.*: HIP utility functions and error handling
  • mlx/backend/rocm/jit_module.*: JIT compilation support using HIPRTC
  • mlx/backend/rocm/device/*.hpp: Device-side utility functions and type definitions
  • mlx/backend/rocm/CMakeLists.txt: ROCm backend build configuration


…ather, scatter, logsumexp, random bits generation, and sorting. Introduce new kernels for efficient computation and integrate with existing ROCm utilities. Update CMake configuration to include new source files and dependencies. Enhance error handling and ensure compatibility with different data types. This commit significantly expands the functionality of the ROCm backend.
@goniz

goniz commented Jan 24, 2026

👑👑👑

@NripeshN
Contributor Author

Can anyone run

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES={based on your GPU}" pip install -e .

Replace {based on your GPU} with your GPU architecture

You can run

rocm-smi

to get your GPU information

@goniz

goniz commented Jan 24, 2026

I'm getting this CMake error:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151" pip install -e .

      -- Configuring done (4.8s)
      CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
      Please set them or make sure they are set and tested correctly in the CMake files:
      /home/goniz/Work/mlx/LAPACK_INCLUDE_DIRS
         used as include directory in directory /home/goniz/Work/mlx
      
      CMake Error in CMakeLists.txt:
        HIP_ARCHITECTURES is empty for target "mlx".
      
      
      CMake Error in CMakeLists.txt:
        HIP_ARCHITECTURES is empty for target "mlx".
      
      
      -- Generating done (0.0s)
      CMake Generate step failed.  Build files cannot be regenerated correctly.

Running on Strix Halo (gfx1151)

@NripeshN
Contributor Author

I'm getting this CMake error:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151" pip install -e .
     -- Configuring done (4.8s)
     CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
     Please set them or make sure they are set and tested correctly in the CMake files:
     /home/goniz/Work/mlx/LAPACK_INCLUDE_DIRS
        used as include directory in directory /home/goniz/Work/mlx
     
     CMake Error in CMakeLists.txt:
       HIP_ARCHITECTURES is empty for target "mlx".
     
     
     CMake Error in CMakeLists.txt:
       HIP_ARCHITECTURES is empty for target "mlx".
     
     
     -- Generating done (0.0s)
     CMake Generate step failed.  Build files cannot be regenerated correctly.

Running on Strix Halo (gfx1151)

Could you retry with the latest push, please? (P.S. keep your fingers crossed while it compiles, it worked for me on the 138th try)😅

… string formatting, replacing fmt library usage. Remove unused event.cpp file. Update kernel name generation and parameter formatting for consistency.
@goniz

goniz commented Jan 25, 2026

  Created wheel for mlx: filename=mlx-0.30.4.dev20260125+cadf18c1-0.editable-cp314-cp314-linux_x86_64.whl size=4722 sha256=72c664adbfc4fb9ec317522a8d83b84f85d599d08bd691d7fec3abfdb6f3a5e9
  Stored in directory: /tmp/pip-ephem-wheel-cache-nt7w6bq0/wheels/8a/63/d1/d7d629a5ff73457822bb71aa527c083674bb19ca314735cd05
Successfully built mlx
Installing collected packages: mlx
Successfully installed mlx-0.30.4.dev20260125+cadf18c1

Now what can I test? 😍

@goniz

goniz commented Jan 25, 2026

I'm getting this:

ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: _ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_

@NripeshN
Contributor Author

I'm getting this:

ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: _ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_

I forgot to test the Python build, my bad. Can you try it now?

Unfortunately I might not be able to help after it compiles, as I don't have an AMD GPU to run tests 😔 I've tried replicating most things from CUDA, so hopefully it works.

Geramy and others added 6 commits March 26, 2026 16:37
Root cause: ensure_row_contiguous_matrix only checked last 2 dimensions.
Arrays from expand_dims (SwitchGLU MoE path) had non-contiguous batch
strides that passed the check but caused OOB when the kernel used flat
pointer arithmetic (x + lhs_idx * M * K).

Fix:
- GatherQMM::eval_gpu: use ensure_row_contiguous (full contiguity check)
  for all inputs, not just ensure_row_contiguous_matrix (last-2-dims)
- Add LHS_B parameter (valid x batch count) to both gather kernels
- Add bounds clamping: lhs_idx < LHS_B, rhs_idx < E
- QuantizedMatmul (non-gather) unchanged — no batch indirection
RMSNorm (called 72x per forward pass):
- Replace rsqrtf() hardware approximation with 1.0f/sqrtf() for IEEE
  compliance (Metal uses precise::rsqrt)
- Match Metal's weight application order: truncate to T between
  normalization and weight multiply (intermediate rounding step)
- Same fix applied to LayerNorm

Sort/ArgSort:
- Add is_sort_floating_v trait that includes __half and hip_bfloat16
  (std::is_floating_point_v is false for these, skipping NaN handling)
- Fix NaN comparison and sentinel values for half types
- Add __half nan_value specialization

SDPA:
- Fix max_score initialization: use Limits<U>::finite_min (-FLT_MAX)
  instead of -1e9f (matches Metal)
- Fix zero-sum normalization edge case

Standalone ops (binary_ops.hpp, unary_ops.hpp):
- Promote __half and hip_bfloat16 through float for Add, Subtract,
  Multiply, Divide (Metal auto-promotes, ROCm doesn't)
- Add float promotion for unary ops with __half inputs

JIT preamble (compiled.cpp):
- Remove redundant float promotion for Add/Subtract/Multiply/Divide
  (already promoted in previous commit, clean up duplicate logic)
The non-uniform-stride batch loop in gemm_and_bias() called rocBLAS
directly (bypassing the naive_gemm wrapper that was patched earlier)
and only handled float32/float64 — bfloat16 and float16 matmuls
silently did nothing, leaving the output buffer uninitialized.

This caused non-deterministic SDPA results for any GQA model (where
n_q_heads != n_kv_heads) at sequence lengths >= 4, with progressively
worse corruption (NaN/Inf at L >= 7). The SDPA fallback decomposition
reshapes Q via unflatten and K/V via expand_dims for GQA broadcasting,
which produces non-uniform batch strides that hit this code path.

Fix: always use naive_gemm_with_offset for the non-uniform-stride
batch loop, matching the approach already used by the single-GEMM
and strided-batched paths.
The supports_sdpa_vector() function listed head_dim=256 as supported,
but the sdpa_vector() dispatch only had cases for D=64, 96, 128.
For D=256, no kernel was launched, leaving the output buffer
uninitialized — causing non-deterministic results for models using
head_dim=256 (e.g. Qwen3-Next) at sequence lengths 1-3.
@Geramy

Geramy commented Mar 27, 2026

I have a lot of changes to merge in. I am testing my port of mlx-swift-lm (https://github.com/lemonade-sdk/lemon-mlx-engine) against the MLX ROCm core (https://github.com/lemonade-sdk/lemon-mlx-core-amd) and I got Qwen3 working. I am working on Qwen3Next right now; it's having weird issues. There are tons of problems with the ROCm backend that I have traced to "different rounding" causing unstable outputs, but a lot of that is fixed now, at least for Qwen models. Once I get your changes merged into my repo I will push a PR into yours with my changes. I have made optimizations as well; there are problems with the fallback system when functions in rocBLAS aren't compatible or don't exist for the architecture.

@Geramy

Geramy commented Mar 27, 2026

Once I get Qwen3Next working at a reasonable speed I will do the PR.

Geramy added 17 commits March 27, 2026 10:04
Merges goniz/rocm-support-fixes: flash attention kernel, allocator
redesign for integrated GPUs, bfloat16 math overloads, QMV
vectorization, depthwise conv1d, event sync improvements, rocBLAS
solution-index dispatch, and upstream main (CUDA, docs, quantization).

Conflicts resolved preferring upstream for most ROCm backend files,
keeping our SliceUpdate kernel and float-promotion JIT approach.
The gather_qmv_warp_shared_kernel (wave-cooperative, shared memory
tiling, vectorized 4-bit unpacking) was only dispatched for 6-bit and
8-bit quantization. 4-bit fell through to the naive gather_qmv_kernel
(1 thread per output, sequential K loop), which was 18.6x slower.

Add bits==4 to the fast dispatch condition. The kernel already handles
4-bit internally with 8-element vectorized unpacking.

Profiled impact (Qwen3-Next 4-bit MoE):
  gather_qmv_kernel:             5193 μs/call → (removed)
  gather_qmv_warp_shared_kernel: N/A          → 279 μs/call (18.6x)
Key changes for Strix Halo / RDNA 3.5 integrated GPU:

1. raw_ptr(): Use hipStreamSynchronize(nullptr) instead of
   hipDeviceSynchronize() for unified memory buffers. Only waits on
   the default stream instead of all streams. Skips the expensive
   move_to_unified_memory() since integrated GPU memory is already
   CPU-accessible (device==-1).

2. malloc(): Integrated GPU path now goes through rocm_unified_malloc()
   which sets device=-1, so raw_ptr() takes the fast path.

3. rocm_unified_malloc(): Integrated GPUs try hipExtMallocWithFlags
   (fine-grained coherent) first, falling back to hipMallocManaged.

Profiled impact on Qwen3-Next 4-bit MoE:
  Generation: 12.0 tok/s → 18.9 tok/s (58% faster)
  Prompt:     2.5 tok/s → 5.2 tok/s (2x faster)
The noshared QMV kernel reads x from global memory redundantly per
warp (each warp reloads the same x vector). The shared variant caches
x in LDS and is significantly faster for decode-sized (M<=8) shapes.

Disable the alignment-based noshared path selection; always use the
shared variant unless K is tiny. This reduces redundant global memory
traffic for dense quantized projections.
For MoE prefill (M>1) with sorted rhs_indices, consecutive batch
elements map to the same expert. The existing gather_qmv_warp_shared
kernel launches B independent blocks that each load the same expert
weights from global memory — 60-75x redundant weight traffic.

New gather_qmv_prefill_kernel groups batch elements into contiguous
runs of same-expert assignments. Each block handles one (run, row, col)
and iterates over all batch elements in the run, reading weights once.
Grid z-dimension = num_runs (~8-10 unique experts) instead of B (~600).

Supports 4-bit and 8-bit affine quantization with vectorized unpacking
(8 elements per iteration for 4-bit, 4 for 8-bit) and fmaf accumulation.

Profiled impact (Qwen3-Next 4-bit MoE, 40-token prompt):
  Prompt: 1.8 tok/s → 6.1 tok/s (3.4x faster)
  gather_qmv total: 502ms → ~150ms
New gather_qmv_wmma_prefill_kernel uses rocWMMA 16x16x16 bf16→f32
tiles for matrix multiply-accumulate during MoE prefill. Each wave32
handles a 16x16 output tile, dequantizing 4-bit weights into shared
memory and using rocwmma::mma_sync for the reduction.

Enabled for gfx11 (RDNA 3/3.5) and gfx12 (RDNA 4) when M >= 16 and
dimensions are 16-aligned. Falls back to scalar kernel otherwise.
Guarded by ROCM_HAS_WMMA macro so gfx9/gfx10 builds are unaffected.

Also restores hipExtMallocWithFlags as primary allocator for APU
(reverts hipMallocManaged experiment — fine-grained coherent gives
better GPU kernel bandwidth).

Profiled impact (Qwen3-Coder-Next 4-bit, Strix Halo gfx1151):
  Prompt (40 tok): 84 tok/s → 117 tok/s (39% faster)
  Qwen3-8B prompt: 33 tok/s → 44 tok/s (33% faster)
  Generation: unchanged at ~18 tok/s
- Remove M%16 alignment requirement: kernel now bounds-checks rows,
  padding with zero for tile positions beyond M.
- Remove right_sorted_ requirement from prefill dispatch: CPU-side sort
  creates sorted index arrays and output permutation for any index order.
- Add out_perm parameter to both WMMA and scalar prefill kernels to
  scatter results back to original batch positions after sorted dispatch.
- Add <algorithm> and <numeric> includes for std::sort/std::iota.

NOTE: MLX's MoE layer (SwitchGLU) currently expands all tokens to
individual M=1 calls via gather_qmm. The prefill kernels (M>1) will
activate when upstream changes batch tokens per-expert. The 4-bit
fast gather_qmv_warp_shared dispatch handles the current M=1 path.
New gather_qmv_expert_batched_kernel finds expert run boundaries
on-GPU via binary search of sorted rhs_indices. Each block handles
one (expert, column) pair and iterates over all tokens for that expert,
loading weights once per expert.

Dispatch condition: E <= 64 and B/E >= 4 (low expert count with many
tokens per expert). For high-expert models (E=512 like Qwen3-Next),
the warp_shared kernel remains faster since most runs have only 1-4
tokens and the per-block run-finding overhead isn't justified.
hipBLASLt provides architecture-tuned GEMM kernels via Tensile,
typically outperforming rocBLAS for bf16/fp16 on RDNA 3.5 and CDNA.

New hipblaslt_gemm() and hipblaslt_gemm_batched() functions with:
- Per-device handle cache (thread-safe, lazily initialized)
- Algorithm heuristic selection (best-of-1 from hipBLASLt)
- RAII guards for all descriptor types
- Persistent workspace allocation (up to 32MB, grown as needed)
- fp32 accumulation for bf16/fp16 inputs

matmul.cpp tries hipBLASLt first for bf16/fp16, falls back to
rocBLAS silently on failure. Float32/64 GEMMs unchanged.
The dequant+GEMM path in QuantizedMatmul now tries hipBLASLt before
rocBLAS for bf16 GEMMs. hipBLASLt selects architecture-tuned kernels
via heuristic algorithm search, significantly outperforming rocBLAS
once the algorithm cache is warm.

New hipblaslt_gemm_raw() allows calling from inside kernel lambdas
with pre-swapped column-major parameters, matching the rocBLAS pattern.

Warm prompt (Qwen3-Coder-Next 4-bit, Strix Halo):
  80 tok/s → 207 tok/s (2.6x faster)

First-call overhead from algorithm search is amortized by the
application warmup pass.
- hipblaslt_gemm_raw() for calling from inside kernel lambdas with
  pre-swapped col-major params. Used in QMM bf16 dequant+GEMM path.
- Warm prompt: 80→207 tok/s with hipBLASLt algorithm cache primed.

- CommandEncoder graph capture API (begin_capture, end_capture, replay,
  reset_graph) using hipStreamBeginCapture/EndCapture/GraphLaunch.
  Infrastructure for future decode acceleration (18→34 tok/s potential).
  Not yet active due to MLX lazy eval incompatibility with capture mode.
Replace the 5-operation copy chain (2 allocs + 2 hipMemcpyAsync + 1 kernel)
with single-dispatch strided copy kernels for non-contiguous arrays.

New kernels:
- strided_row_copy_kernel: inner-contiguous with outer stride gap (common
  pattern from take/gather_sort). Uses 4-byte word copies when aligned.
- strided_general_copy_kernel: arbitrary strides, shapes/strides passed
  as by-value structs (zero device allocation).

Tiered dispatch in ensure_row_contiguous_matrix:
1. Already contiguous → return (fast path, unchanged)
2. Inner-contiguous outer gap → strided_row_copy_kernel (1 dispatch)
3. General non-contiguous → strided_general_copy_kernel (1 dispatch)
4. ndim > 10 → old contiguous_copy_gpu fallback

Net: each non-contiguous copy drops from 5 GPU operations to 1.
Coarser size buckets for large allocations improve buffer cache hit
rate during LLM decode. Without this, slightly different allocation
sizes (e.g., 1.01MB vs 1.02MB) miss the cache and trigger
hipExtMallocWithFlags at ~7ms each.

Previous: page-aligned (16KB granularity) for all sizes >= 16KB
New: page-aligned for 16KB-1MB, power-of-2 for >= 1MB

Trades up to 2x memory waste for large buffers in exchange for
dramatically fewer cache misses during steady-state decode.
The power-of-2 rounding for >= 1MB allocations caused OOM by doubling
large allocations that exceeded the 2GB device-local VRAM on iGPU.
Reverted to page-aligned (16KB) rounding for all large sizes.

hipExtMallocWithFlags remains the primary path for iGPU (best GPU
bandwidth via fine-grained coherent access). Falls back to
hipMallocManaged for allocations that exceed VRAM capacity,
accessing the full system RAM (126GB on Strix Halo).
@NripeshN
Contributor Author

I have a lot of changes to merge

I have added you as a collaborator on my fork, so you should be able to push changes directly to this branch (and therefore directly to this PR). Again, amazing work🚀

ROCm: WMMA prefill, hipBLASLt, 4-bit MoE dispatch, QMV tuning, iGPU allocator
@Geramy

Geramy commented Mar 30, 2026

We are still about 2x slower on almost all things, tps and pp, but it's workable and actually usable right now. I am going to look further into how the Mac eval and allocator work to see if I can get some ideas for optimizing allocation. The allocations, if you benchmark and profile them, are off the chain. And thanks for the invite!

@Geramy

Geramy commented Mar 30, 2026

@goniz do you want to test this again? We should be at a point where it's working.

@Geramy

Geramy commented Mar 30, 2026

@NripeshN if you are interested this is a very interesting read.
https://seb-v.github.io/optimization/update/2025/01/20/Fast-GPU-Matrix-multiplication.html
It really exposes amazing ways to optimize kernels for RDNA 3 and up.



Development

Successfully merging this pull request may close these issues.

Add ROCm Support for AMD GPUs

8 participants